package com.stars.easyms.rest.initializer;

import com.stars.easyms.base.bean.UserInfo;
import com.stars.easyms.base.util.AESUtil;
import com.stars.easyms.monitor.handler.EasyMsRequestMappingHandlerMapping;
import com.stars.easyms.rest.annotation.EasyMsRestController;
import com.stars.easyms.rest.annotation.EasyMsRestMapping;
import com.stars.easyms.rest.constant.RestConstants;
import com.stars.easyms.rest.bean.RestInfo;
import com.stars.easyms.rest.enums.ServiceType;
import com.stars.easyms.rest.exception.RestRuntimeException;
import com.stars.easyms.rest.properties.EasyMsRestProperties;
import com.stars.easyms.rest.RestService;
import com.stars.easyms.base.util.*;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiOperation;
import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.MethodInterceptor;
import net.sf.cglib.proxy.MethodProxy;
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

/**
 * Registers REST services.
 *
 * <p>Scans the configured base packages (or, if none are configured, the Spring Boot application package) for
 * classes and interfaces annotated with {@link EasyMsRestController}. Every method annotated with
 * {@link EasyMsRestMapping} is validated against the {@link RestService} it delegates to. A POST request mapping
 * is then registered for each interface code in the {@link EasyMsRequestMappingHandlerMapping}.
 *
 * <p>Registration runs at most once per JVM (guarded by a static flag). A failure is logged and terminates the
 * process with {@code System.exit(1)}.
 *
 * @author guoguifang
 * @date 2018-10-12 11:24
 * @since 1.0.0
 */
public final class EasyMsRestInitializer {

    private static final Logger logger = LoggerFactory.getLogger(EasyMsRestInitializer.class);

    /** Guards {@link #init()} so that registration happens only once per JVM. */
    private static final AtomicBoolean IS_REGISTER = new AtomicBoolean(false);

    /** Name of the method a {@link RestService} implementation must declare. */
    private static final String REST_SERVICE_METHOD_NAME = "execute";

    /**
     * ApiModel name -> class. Filled during registration to detect two different classes that share one model
     * name, and cleared once registration finishes.
     */
    private static final Map<String, Class<?>> API_MODEL_MAP = new ConcurrentHashMap<>();

    private final EasyMsRestProperties easyMsRestProperties;

    private final EasyMsRequestMappingHandlerMapping easyMsRestHandlerMapping;

    private final RequestMappingInfo.BuilderConfiguration config = new RequestMappingInfo.BuilderConfiguration();

    /** User info class stored on every RestInfo; can be overridden through configuration. */
    private Class<?> userInfoClass = UserInfo.class;

    /** Global AES secret; only set when global encryption is enabled and both secret and iv are configured. */
    private String globalSecret;

    /** Global AES iv; only set together with {@link #globalSecret}. */
    private String globalIv;

    public EasyMsRestInitializer(EasyMsRestProperties easyMsRestProperties, EasyMsRequestMappingHandlerMapping easyMsRestHandlerMapping) {
        this.easyMsRestProperties = easyMsRestProperties;
        this.easyMsRestHandlerMapping = easyMsRestHandlerMapping;
    }

    /**
     * Registers all REST services. Only the first call in the JVM does any work. Any failure is logged and
     * terminates the process.
     */
    public void init() {
        if (!IS_REGISTER.compareAndSet(false, true)) {
            return;
        }
        logger.debug("注册rest服务开始... ...");
        try {
            this.register();
            logger.debug("注册rest服务完成！");
        } catch (Exception e) {
            logger.error("注册rest服务失败！", e);
            System.exit(1);
        }
    }

    /**
     * Resolves the global configuration, then registers every EasyMsRestController found in the scan packages.
     */
    private void register() throws ReflectiveOperationException {
        Set<String> basePackageSet = resolveBasePackages();
        resolveUserInfoClass();
        resolveGlobalEncryption();

        // Collect all classes and interfaces annotated with EasyMsRestController
        Set<Class<?>> restControllerClassSet = ReflectUtil.getAllClassByAnnotation(basePackageSet, EasyMsRestController.class);
        if (restControllerClassSet.isEmpty()) {
            logger.warn("No valid EasyMsRestController service class or interface was found！");
            return;
        }

        try {
            for (Class<?> restControllerClass : restControllerClassSet) {
                registerRestController(restControllerClass);
            }
        } finally {
            // The ApiModel map is only needed for duplicate detection during registration
            API_MODEL_MAP.clear();
        }
    }

    /** Returns the configured scan packages, defaulting to the Spring Boot application package. */
    private Set<String> resolveBasePackages() {
        Set<String> basePackageSet = easyMsRestProperties.getBasePackage();
        if (basePackageSet != null && !basePackageSet.isEmpty()) {
            return basePackageSet;
        }
        String springApplicationPackageName = SpringBootUtil.getSpringApplicationPackageName();
        if (springApplicationPackageName == null) {
            throw new RestRuntimeException("Configure the rest scan path with [spring.rest.base-package]!");
        }
        Set<String> defaultBasePackageSet = new HashSet<>();
        defaultBasePackageSet.add(springApplicationPackageName);
        return defaultBasePackageSet;
    }

    /** Loads the configured user info class, keeping the default {@link UserInfo} if it cannot be found. */
    private void resolveUserInfoClass() {
        String userInfoClassName = easyMsRestProperties.getUserInfoClass();
        if (StringUtils.isBlank(userInfoClassName)) {
            return;
        }
        try {
            userInfoClass = ClassUtils.getClass(EasyMsRestInitializer.class.getClassLoader(), userInfoClassName);
        } catch (ClassNotFoundException e) {
            // Best effort: keep the default class, but make the misconfiguration visible
            logger.warn("User info class [{}] not found, falling back to [{}]!", userInfoClassName, userInfoClass.getName());
        }
    }

    /**
     * Reads and validates the global encryption secret/iv. They are only used when global encryption is enabled
     * and both values are configured.
     */
    private void resolveGlobalEncryption() {
        if (easyMsRestProperties.isEncrypt()
                && StringUtils.isNotBlank(easyMsRestProperties.getSecret())
                && StringUtils.isNotBlank(easyMsRestProperties.getIv())) {
            globalSecret = easyMsRestProperties.getSecret();
            globalIv = easyMsRestProperties.getIv();
            checkRestSecretAndIv(globalSecret, globalIv);
        }
    }

    /** Registers all EasyMsRestMapping methods declared directly on one EasyMsRestController class or interface. */
    private void registerRestController(Class<?> restControllerClass) throws ReflectiveOperationException {
        // Double check: only classes actually carrying EasyMsRestController are registered
        if (!restControllerClass.isAnnotationPresent(EasyMsRestController.class)) {
            return;
        }

        // Class-level path: resolve property placeholders, then normalize it as a request-mapping path
        EasyMsRestController easyMsRestController = restControllerClass.getAnnotation(EasyMsRestController.class);
        String classRequestMappingUri = StringFormatUtil.formatRequestMappingPath(
                PropertyPlaceholderUtil.replace(easyMsRestController.path(), easyMsRestController.path()));

        Object restControllerInstance = newInstance(restControllerClass);

        for (Method restControllerMethod : restControllerClass.getDeclaredMethods()) {
            // Skip compiler-generated methods. Since JDK 8 javac copies annotations onto bridge methods, which
            // would otherwise be registered a second time with erased parameter types.
            if (restControllerMethod.isBridge() || restControllerMethod.isSynthetic()) {
                continue;
            }
            registerRestControllerMethod(restControllerClass, restControllerMethod, restControllerInstance, classRequestMappingUri);
        }
    }

    /**
     * Validates one EasyMsRestMapping method and registers a POST mapping for each of its interface codes.
     * Methods without the annotation are ignored.
     */
    @SuppressWarnings("unchecked")
    private void registerRestControllerMethod(Class<?> restControllerClass, Method restControllerMethod,
                                              Object restControllerInstance, String classRequestMappingUri) {
        // Only methods annotated with EasyMsRestMapping are registered
        if (!restControllerMethod.isAnnotationPresent(EasyMsRestMapping.class)) {
            return;
        }

        String restControllerMethodName = restControllerMethod.getName();
        EasyMsRestMapping easyMsRestMapping = restControllerMethod.getDeclaredAnnotation(EasyMsRestMapping.class);

        // Resolve encryption settings: a secret/iv on the annotation overrides the global one
        String secret = null;
        String iv = null;
        String encryptRequestKey = null;
        String encryptResponseKey = null;
        boolean isEncrypt = easyMsRestMapping.encrypt();
        if (isEncrypt) {
            secret = globalSecret;
            iv = globalIv;
            boolean isNeedCheckRestSecret = false;
            if (StringUtils.isNotBlank(easyMsRestMapping.secret()) && !easyMsRestMapping.secret().equals(globalSecret)) {
                secret = easyMsRestMapping.secret();
                isNeedCheckRestSecret = true;
            }
            if (StringUtils.isNotBlank(easyMsRestMapping.iv()) && !easyMsRestMapping.iv().equals(globalIv)) {
                iv = easyMsRestMapping.iv();
                isNeedCheckRestSecret = true;
            }
            // The global pair was already validated; only a changed pair needs checking
            if (isNeedCheckRestSecret) {
                checkRestSecretAndIv(secret, iv);
            }
            if (secret != null && iv != null) {
                encryptRequestKey = easyMsRestMapping.encryptRequestKey();
                encryptRequestKey = StringUtils.isNotBlank(encryptRequestKey) ? encryptRequestKey : easyMsRestProperties.getEncryptRequestKey();
                encryptResponseKey = easyMsRestMapping.encryptResponseKey();
                encryptResponseKey = StringUtils.isNotBlank(encryptResponseKey) ? encryptResponseKey : easyMsRestProperties.getEncryptResponseKey();
            } else {
                logger.warn("The restController[{}] method[{}] requires encryption but no secret/iv is available, encryption is disabled!",
                        restControllerClass.getName(), restControllerMethodName);
                isEncrypt = false;
            }
        }

        // Sorted method -> methodFor mapping (may be null when none is declared)
        Map<String, String> methodMap = getMethodMap(restControllerClass, restControllerMethodName, easyMsRestMapping);

        // Validate the controller method signature
        checkRestControllerMethod(restControllerClass, restControllerMethod, methodMap);

        // Validate the RestService implementation the method delegates to
        Class<? extends RestService> restServiceClass = easyMsRestMapping.service();
        checkRestService(restControllerClass, restControllerMethod, restServiceClass);

        // Interface version path and service type
        String interfaceVersionPath = StringFormatUtil.formatRequestMappingPath(easyMsRestMapping.version());
        ServiceType serviceType = ServiceType.forCode(easyMsRestMapping.type());
        if (serviceType == null) {
            throw new RestRuntimeException("The restController[{}] method[{}] annotation[EasyMsRestMapping] type must be 'T' or 'Q'!",
                    restControllerClass.getName(), restControllerMethodName);
        }

        // Interface display name: annotation name, then ApiOperation value, then class.method
        String interfaceName = easyMsRestMapping.name();
        if (StringUtils.isBlank(interfaceName)) {
            interfaceName = restControllerMethod.isAnnotationPresent(ApiOperation.class)
                    ? restControllerMethod.getAnnotation(ApiOperation.class).value()
                    : restControllerClass.getName() + "." + restControllerMethodName;
        }

        // RestService instance: an existing Spring bean, otherwise a newly registered one
        RestService restServiceInstance = ApplicationContextHolder.getBean(restServiceClass);
        if (restServiceInstance == null) {
            restServiceInstance = BeanUtil.registerBean((ConfigurableApplicationContext) ApplicationContextHolder.getApplicationContext(),
                    restServiceClass.getSimpleName(), restServiceClass);
        }

        // Whether the interface may be exposed to external callers
        boolean expose = easyMsRestMapping.expose();

        // Loop-invariant parameter/return types; void is represented as Void
        Class<?>[] restControllerMethodParameterTypes = restControllerMethod.getParameterTypes();
        Class<?> parameterType = restControllerMethodParameterTypes.length == 0 ? Void.class : restControllerMethodParameterTypes[0];
        Class<?> returnType = restControllerMethod.getReturnType() == void.class ? Void.class : restControllerMethod.getReturnType();

        // Interface codes default to the method name. Each yields /classRequestMappingUri/code/version.
        String[] easyMsRestMappingCodes = easyMsRestMapping.code();
        if (easyMsRestMappingCodes.length == 0) {
            easyMsRestMappingCodes = new String[]{restControllerMethodName};
        }
        for (String restRequestMappingCode : easyMsRestMappingCodes) {
            String methodRequestMappingPath = StringFormatUtil.formatRequestMappingPath(restRequestMappingCode);
            String interfaceCodePath = StringFormatUtil.formatRequestMappingPath(classRequestMappingUri + methodRequestMappingPath);
            String requestMappingPath = StringFormatUtil.formatRequestMappingPath(interfaceCodePath + interfaceVersionPath);

            // Build the RestInfo describing this endpoint
            RestInfo restInfo = new RestInfo();
            restInfo.setName(interfaceName);
            restInfo.setCode(interfaceCodePath.substring(1));
            restInfo.setVersion(interfaceVersionPath.substring(1));
            restInfo.setRequestMappingPath(requestMappingPath);
            restInfo.setMethodMap(methodMap);
            restInfo.setMethodPriority(easyMsRestMapping.methodPriority());
            restInfo.setType(serviceType);
            restInfo.setRequestMethods(new RequestMethod[]{RequestMethod.POST});
            restInfo.setRestControllerClass(restControllerClass);
            restInfo.setRestControllerMethod(restControllerMethod);
            restInfo.setRestServiceClass(restServiceClass);
            restInfo.setRestServiceInstance(restServiceInstance);
            restInfo.setParameterType(parameterType);
            restInfo.setReturnType(returnType);
            restInfo.setAllowEndIsForwardSlash(easyMsRestProperties.isAllowEndIsForwardSlash() && easyMsRestMapping.allowEndIsForwardSlash());
            restInfo.setUserInfoClass(userInfoClass);
            restInfo.setExpose(expose);

            if (isEncrypt) {
                restInfo.setEncrypt(true);
                restInfo.setSecret(secret);
                restInfo.setIv(iv);
                restInfo.setEncryptRequestKey(encryptRequestKey);
                restInfo.setEncryptResponseKey(encryptResponseKey);
            }

            // One holder per request mapping path (e.g. /test/v1 or /test/v2)
            RequestMappingPathForRestInfo requestMappingPathForRestInfo = RestConstants.REQUEST_MAPPING_PATH_MAP.computeIfAbsent(requestMappingPath, RequestMappingPathForRestInfo::new);

            // The holder lets one path carry several method variants; the returned path looks like
            // /test/v1?method1=methodFor1&method2=methodFor2
            String registerRequestMappingPath = requestMappingPathForRestInfo.add(restInfo);
            restInfo.setRegisterRequestMappingPath(registerRequestMappingPath);

            // Register the Spring request mapping
            try {
                RequestMappingInfo requestMappingInfo = RequestMappingInfo.paths(registerRequestMappingPath)
                        .methods(RequestMethod.POST).options(config).build();
                easyMsRestHandlerMapping.registerMapping(requestMappingInfo, restControllerInstance, restControllerMethod);
            } catch (Exception e) {
                throw new RestRuntimeException("The restController[{}] method[{}] register RequestMapping[{}] failure!", restControllerClass.getName(),
                        restControllerMethodName, registerRequestMappingPath, e);
            }

            logger.info("Register rest service: {} ==> {}.{}() success!", registerRequestMappingPath,
                    restControllerClass.getName(), restControllerMethodName);
            RestConstants.REGISTER_REQUEST_MAPPING_PATH_MAP.put(registerRequestMappingPath, restInfo);

            // A trailing '/' is only accepted when explicitly allowed (disabled by default)
            if (restInfo.isAllowEndIsForwardSlash()) {
                RestConstants.REQUEST_MAPPING_PATH_MAP.put(requestMappingPath + "/", requestMappingPathForRestInfo);
            }

            registerDefaultVersionAbbreviation(restInfo, requestMappingPath, registerRequestMappingPath, requestMappingPathForRestInfo);
        }
    }

    /**
     * When the last path segment of {@code requestMappingPath} is the default interface version (compared
     * case-insensitively), records the version-less variant on {@code restInfo}. Example: /test/v1 -> /test.
     * The short path is mapped only when no other endpoint already owns it.
     */
    private void registerDefaultVersionAbbreviation(RestInfo restInfo, String requestMappingPath, String registerRequestMappingPath,
                                                    RequestMappingPathForRestInfo requestMappingPathForRestInfo) {
        int requestMappingPathLastSlashIndex = requestMappingPath.lastIndexOf('/');
        if (!RestConstants.DEFAULT_INTERFACE_VERSION.equalsIgnoreCase(requestMappingPath.substring(requestMappingPathLastSlashIndex + 1))) {
            return;
        }
        String abbreviationRequestMappingPath = requestMappingPath.substring(0, requestMappingPathLastSlashIndex);
        restInfo.setAbbrRequestMappingPath(abbreviationRequestMappingPath);
        restInfo.setAbbrRegisterRequestMappingPath(
                abbreviateRegisterRequestMappingPath(requestMappingPath, registerRequestMappingPath, abbreviationRequestMappingPath));
        if (!RestConstants.REQUEST_MAPPING_PATH_MAP.containsKey(abbreviationRequestMappingPath)) {
            RestConstants.REQUEST_MAPPING_PATH_MAP.put(abbreviationRequestMappingPath, requestMappingPathForRestInfo);
            if (restInfo.isAllowEndIsForwardSlash()) {
                RestConstants.REQUEST_MAPPING_PATH_MAP.put(abbreviationRequestMappingPath + "/", requestMappingPathForRestInfo);
            }
        }
    }

    /**
     * Replaces the {@code requestMappingPath} prefix of the registered path (which may carry a "?method=..."
     * suffix) with its abbreviation. This removes only the trailing version segment, whatever its case. A blind
     * string replace would also strip "/v1" occurrences elsewhere in the path and miss "/V1".
     */
    private static String abbreviateRegisterRequestMappingPath(String requestMappingPath, String registerRequestMappingPath,
                                                               String abbreviationRequestMappingPath) {
        if (registerRequestMappingPath.startsWith(requestMappingPath)) {
            return abbreviationRequestMappingPath + registerRequestMappingPath.substring(requestMappingPath.length());
        }
        // Fallback for an unexpected registered-path shape: keep the previous behavior
        return registerRequestMappingPath.replace("/" + RestConstants.DEFAULT_INTERFACE_VERSION, "");
    }

    /**
     * Builds a sorted method -> methodFor map from the annotation. Pairs where either side is blank are skipped.
     *
     * @return the map, or {@code null} when no non-blank pair is declared
     * @throws RestRuntimeException when method and methodFor have different lengths
     */
    private Map<String, String> getMethodMap(Class<?> restControllerClass, String restControllerMethodName, EasyMsRestMapping easyMsRestMapping) {
        // method and methodFor must be declared pairwise
        String[] interfaceMethods = easyMsRestMapping.method();
        String[] interfaceMethodFors = easyMsRestMapping.methodFor();
        if (interfaceMethods.length != interfaceMethodFors.length) {
            throw new RestRuntimeException("The number of restController[{}] method[{}] annotation[EasyMsRestMapping] method and methodFor is inconsistent!",
                    restControllerClass.getName(), restControllerMethodName);
        }

        // TreeMap keeps keys sorted so the generated registered path is deterministic
        Map<String, String> methodMap = null;
        for (int i = 0; i < interfaceMethods.length; i++) {
            if (StringUtils.isNotBlank(interfaceMethods[i]) && StringUtils.isNotBlank(interfaceMethodFors[i])) {
                if (methodMap == null) {
                    methodMap = new TreeMap<>();
                }
                methodMap.put(interfaceMethods[i], interfaceMethodFors[i]);
            }
        }
        return methodMap;
    }

    /**
     * Validates a restController method: its parameter and its return type.
     */
    private void checkRestControllerMethod(Class<?> restControllerClass, Method restControllerMethod, Map<String, String> methodMap) {
        checkRestControllerMethodParameter(restControllerClass, restControllerMethod, methodMap);
        checkRestControllerMethodReturn(restControllerClass, restControllerMethod);
    }

    /**
     * Validates the parameter of a restController method.
     * <ul>
     *   <li>The method has at most one parameter.</li>
     *   <li>The parameter type declares a field for every key of {@code methodMap}.</li>
     *   <li>The parameter type's ApiModel name is not used by a different class.</li>
     * </ul>
     */
    private void checkRestControllerMethodParameter(Class<?> restControllerClass, Method restControllerMethod, Map<String, String> methodMap) {
        Class<?>[] restControllerMethodParameterTypes = restControllerMethod.getParameterTypes();
        if (restControllerMethodParameterTypes.length > 1) {
            throw new RestRuntimeException("The restController[{}] method[{}] can only have no or one parameter!",
                    restControllerClass.getName(), restControllerMethod.getName());
        }
        if (restControllerMethodParameterTypes.length == 0 || restControllerMethodParameterTypes[0] == Void.class) {
            return;
        }
        Class<?> restControllerMethodParameterType = restControllerMethodParameterTypes[0];
        // Every method key must exist as a field (including non-public ones) on the parameter type
        if (methodMap != null) {
            for (String method : methodMap.keySet()) {
                if (FieldUtils.getField(restControllerMethodParameterType, method, true) == null) {
                    throw new RestRuntimeException("The restController[{}] method[{}] parameterType[{}] does not contain field[{}]!",
                            restControllerClass.getName(), restControllerMethod.getName(), restControllerMethodParameterType.getName(), method);
                }
            }
        }
        checkApiModelValue(restControllerClass, restControllerMethod, restControllerMethodParameterType);
    }

    /**
     * Validates the return type of a restController method: its ApiModel name must not be used by a different
     * class.
     */
    private void checkRestControllerMethodReturn(Class<?> restControllerClass, Method restControllerMethod) {
        Class<?> restControllerClassMethodReturnType = restControllerMethod.getReturnType();
        if (restControllerClassMethodReturnType != Void.class && restControllerClassMethodReturnType != void.class) {
            checkApiModelValue(restControllerClass, restControllerMethod, restControllerClassMethodReturnType);
        }
    }

    /**
     * Ensures no two different classes share one ApiModel name, then records {@code checkClass} under its name.
     */
    private void checkApiModelValue(Class<?> restControllerClass, Method restControllerMethod, Class<?> checkClass) {
        ApiModel apiModel = AnnotationUtils.findAnnotation(checkClass, ApiModel.class);
        // @ApiModel#value defaults to "" meaning "use the class name". Treating "" as a real name would make every
        // @ApiModel without an explicit value collide with the others.
        String apiModelValue = apiModel != null && StringUtils.isNotBlank(apiModel.value())
                ? apiModel.value()
                : checkClass.getSimpleName();
        Class<?> existsClass = API_MODEL_MAP.get(apiModelValue);
        if (existsClass != null && existsClass != checkClass) {
            throw new RestRuntimeException("The restController[{}] method[{}] parameterType[{}] ApiModel value[{}] is already registered by the class[{}]!",
                    restControllerClass.getName(), restControllerMethod.getName(), checkClass.getName(),
                    apiModelValue, existsClass.getName());
        }
        API_MODEL_MAP.put(apiModelValue, checkClass);
    }

    /**
     * Ensures {@code restServiceClass} declares an {@value #REST_SERVICE_METHOD_NAME} method that matches the
     * controller method.
     * <ul>
     *   <li>It takes exactly one parameter: {@code Void} when the controller method has none, otherwise the same
     *       type as the controller method's parameter.</li>
     *   <li>It returns the same type. void and Void are treated as interchangeable.</li>
     * </ul>
     * Only methods declared directly on {@code restServiceClass} are considered.
     */
    private void checkRestService(Class<?> restControllerClass, Method restControllerMethod, Class<? extends RestService> restServiceClass) {
        Set<Method> restServiceMethods = Arrays.stream(restServiceClass.getDeclaredMethods())
                .filter(method -> REST_SERVICE_METHOD_NAME.equals(method.getName()))
                .collect(Collectors.toSet());
        if (restServiceMethods.isEmpty()) {
            throw new RestRuntimeException("The restService[{}] does not have method [{}]!", restServiceClass.getName(), REST_SERVICE_METHOD_NAME);
        }

        Class<?>[] restControllerMethodParameterTypes = restControllerMethod.getParameterTypes();
        Class<?> restControllerMethodReturnType = restControllerMethod.getReturnType();
        boolean controllerReturnsNothing = restControllerMethodReturnType == void.class || restControllerMethodReturnType == Void.class;
        boolean checkSucc = false;
        for (Method restServiceMethod : restServiceMethods) {
            Class<?>[] restServiceMethodParameterTypes = restServiceMethod.getParameterTypes();
            boolean parameterMismatch = restServiceMethodParameterTypes.length != 1
                    || (restControllerMethodParameterTypes.length == 0 && restServiceMethodParameterTypes[0] != Void.class)
                    || (restControllerMethodParameterTypes.length == 1 && restServiceMethodParameterTypes[0] != restControllerMethodParameterTypes[0]);
            if (parameterMismatch) {
                continue;
            }

            Class<?> restServiceMethodReturnType = restServiceMethod.getReturnType();
            boolean serviceReturnsNothing = restServiceMethodReturnType == void.class || restServiceMethodReturnType == Void.class;
            if ((controllerReturnsNothing && serviceReturnsNothing) || restServiceMethodReturnType == restControllerMethodReturnType) {
                checkSucc = true;
                break;
            }
        }

        if (!checkSucc) {
            throw new RestRuntimeException("The restService[{}] method[{}] does not have the same parameterTypes or returnType as the restController[{}] method[{}]!",
                    restServiceClass.getName(), REST_SERVICE_METHOD_NAME, restControllerClass.getName(), restControllerMethod.getName());
        }
    }

    /**
     * Validates an AES secret/iv pair by building a cipher with it.
     *
     * @throws RestRuntimeException when the cipher cannot be created
     */
    private void checkRestSecretAndIv(String secret, String iv) {
        try {
            AESUtil.getCipher(secret, iv, true);
        } catch (Exception e) {
            // Never put the key material itself into the message: init() writes it to the error log
            throw new RestRuntimeException("The secret(length={}) or iv(length={}) is invalid!",
                    StringUtils.length(secret), StringUtils.length(iv), e);
        }
    }

    /**
     * Creates the handler object registered for a controller.
     * <ul>
     *   <li>Concrete classes are created through their no-arg constructor.</li>
     *   <li>Abstract classes and interfaces get a CGLIB subclass whose intercepted methods all return
     *       {@code null}. Presumably the handler mapping only needs an instance and the real work goes to the
     *       RestService (confirm).</li>
     * </ul>
     */
    private Object newInstance(Class<?> clazz) throws ReflectiveOperationException {
        if (Modifier.isAbstract(clazz.getModifiers())) {
            Enhancer enhancer = new Enhancer();
            enhancer.setSuperclass(clazz);
            enhancer.setCallback((MethodInterceptor) (Object obj, Method method, Object[] params, MethodProxy methodProxy) -> null);
            return enhancer.create();
        }
        // Class#newInstance is deprecated and rethrows checked constructor exceptions unwrapped
        return clazz.getDeclaredConstructor().newInstance();
    }
}
